Skip to content

[MLIR][Memref] Improve expand-strided-metadata pass #129642

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

arnab-polymage
Copy link
Contributor

Add support to obtain expanded sizes and strides of result of memref.expand_shape op having multiple dynamic dims within a reassociation set.

@llvmbot
Copy link
Member

llvmbot commented Mar 4, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-memref

Author: Arnab Dutta (arnab-polymage)

Changes

Add support to obtain expanded sizes and strides of result of memref.expand_shape op having multiple dynamic dims within a reassociation set.


Patch is 22.62 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/129642.diff

3 Files Affected:

  • (modified) mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp (+60-94)
  • (modified) mlir/test/Dialect/MemRef/expand-strided-metadata.mlir (+55-116)
  • (modified) mlir/test/Examples/transform/ChH/full.mlir (+1-1)
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
index b69cbabe0dde9..fc5c775b25a73 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
@@ -253,14 +253,9 @@ struct ExtractStridedMetadataOpSubviewFolder
 /// Compute the expanded sizes of the given \p expandShape for the
 /// \p groupId-th reassociation group.
 /// \p origSizes hold the sizes of the source shape as values.
-/// This is used to compute the new sizes in cases of dynamic shapes.
-///
-/// sizes#i =
-///     baseSizes#groupId / product(expandShapeSizes#j,
-///                                  for j in group excluding reassIdx#i)
-/// Where reassIdx#i is the reassociation index at index i in \p groupId.
-///
-/// \post result.size() == expandShape.getReassociationIndices()[groupId].size()
+/// For static dim sizes, we take the values from the result type
+/// of \p expandShape. For dynamic dims, we take the values from the
+/// output_shape attribute.
 ///
 /// TODO: Move this utility function directly within ExpandShapeOp. For now,
 /// this is not possible because this function uses the Affine dialect and the
@@ -275,42 +270,27 @@ getExpandedSizes(memref::ExpandShapeOp expandShape, OpBuilder &builder,
 
   unsigned groupSize = reassocGroup.size();
   SmallVector<OpFoldResult> expandedSizes(groupSize);
-
-  uint64_t productOfAllStaticSizes = 1;
-  std::optional<unsigned> dynSizeIdx;
   MemRefType expandShapeType = expandShape.getResultType();
-
-  // Fill up all the statically known sizes.
+  DenseMap<unsigned, Value> dynSizes;
+  Operation::operand_range dynOutShapes = expandShape.getOutputShape();
+  for (unsigned i = 0, dynCount = 0; i < expandShapeType.getRank(); i++) {
+    if (expandShapeType.isDynamicDim(i))
+      dynSizes[i] = dynOutShapes[dynCount++];
+  }
   for (unsigned i = 0; i < groupSize; ++i) {
-    uint64_t dimSize = expandShapeType.getDimSize(reassocGroup[i]);
+    unsigned index = reassocGroup[i];
+    uint64_t dimSize = expandShapeType.getDimSize(index);
     if (ShapedType::isDynamic(dimSize)) {
-      assert(!dynSizeIdx && "There must be at most one dynamic size per group");
-      dynSizeIdx = i;
+      expandedSizes[i] = dynSizes[index];
       continue;
     }
-    productOfAllStaticSizes *= dimSize;
     expandedSizes[i] = builder.getIndexAttr(dimSize);
   }
-
-  // Compute the dynamic size using the original size and all the other known
-  // static sizes:
-  // expandSize = origSize / productOfAllStaticSizes.
-  if (dynSizeIdx) {
-    AffineExpr s0 = builder.getAffineSymbolExpr(0);
-    expandedSizes[*dynSizeIdx] = makeComposedFoldedAffineApply(
-        builder, expandShape.getLoc(), s0.floorDiv(productOfAllStaticSizes),
-        origSizes[groupId]);
-  }
-
   return expandedSizes;
 }
 
 /// Compute the expanded strides of the given \p expandShape for the
 /// \p groupId-th reassociation group.
-/// \p origStrides and \p origSizes hold respectively the strides and sizes
-/// of the source shape as values.
-/// This is used to compute the strides in cases of dynamic shapes and/or
-/// dynamic stride for this reassociation group.
 ///
 /// strides#i =
 ///     origStrides#reassDim * product(expandShapeSizes#j, for j in
@@ -320,11 +300,8 @@ getExpandedSizes(memref::ExpandShapeOp expandShape, OpBuilder &builder,
 /// and expandShapeSizes#j is either:
 /// - The constant size at dimension j, derived directly from the result type of
 ///   the expand_shape op, or
-/// - An affine expression: baseSizes#reassDim / product of all constant sizes
-///   in expandShapeSizes. (Remember expandShapeSizes has at most one dynamic
-///   element.)
-///
-/// \post result.size() == expandShape.getReassociationIndices()[groupId].size()
+/// - The dynamic size  at dimension j, derived from the output_shape attribute
+///   of the expand shape op.
 ///
 /// TODO: Move this utility function directly within ExpandShapeOp. For now,
 /// this is not possible because this function uses the Affine dialect and the
@@ -334,74 +311,63 @@ SmallVector<OpFoldResult> getExpandedStrides(memref::ExpandShapeOp expandShape,
                                              ArrayRef<OpFoldResult> origSizes,
                                              ArrayRef<OpFoldResult> origStrides,
                                              unsigned groupId) {
-  SmallVector<int64_t, 2> reassocGroup =
-      expandShape.getReassociationIndices()[groupId];
+  auto reassocIndices = expandShape.getReassociationIndices();
+  unsigned currIdx = 0;
+  for (unsigned i = 0; i < groupId; i++)
+    currIdx += reassocIndices[i].size();
+  SmallVector<int64_t, 2> reassocGroup = reassocIndices[groupId];
   assert(!reassocGroup.empty() &&
          "Reassociation group should have at least one dimension");
 
   unsigned groupSize = reassocGroup.size();
+  llvm::errs() << "groupSize = " << groupSize << "\n";
   MemRefType expandShapeType = expandShape.getResultType();
-
-  std::optional<int64_t> dynSizeIdx;
-
   // Fill up the expanded strides, with the information we can deduce from the
   // resulting shape.
-  uint64_t currentStride = 1;
+  Location loc = expandShape.getLoc();
+  Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
   SmallVector<OpFoldResult> expandedStrides(groupSize);
-  for (int i = groupSize - 1; i >= 0; --i) {
-    expandedStrides[i] = builder.getIndexAttr(currentStride);
-    uint64_t dimSize = expandShapeType.getDimSize(reassocGroup[i]);
-    if (ShapedType::isDynamic(dimSize)) {
-      assert(!dynSizeIdx && "There must be at most one dynamic size per group");
-      dynSizeIdx = i;
-      continue;
-    }
-
-    currentStride *= dimSize;
+  DenseMap<int, Value> dynSizes;
+  unsigned dynCount = 0;
+  Operation::operand_range dynOutShapes = expandShape.getOutputShape();
+  for (unsigned i = 0; i < expandShapeType.getRank(); i++) {
+    if (expandShapeType.isDynamicDim(i))
+      dynSizes[i] = dynOutShapes[dynCount++];
   }
-
-  // Collect the statically known information about the original stride.
-  Value source = expandShape.getSrc();
-  auto sourceType = cast<MemRefType>(source.getType());
-  auto [strides, offset] = sourceType.getStridesAndOffset();
-
-  OpFoldResult origStride = ShapedType::isDynamic(strides[groupId])
-                                ? origStrides[groupId]
-                                : builder.getIndexAttr(strides[groupId]);
-
-  // Apply the original stride to all the strides.
-  int64_t doneStrideIdx = 0;
-  // If we saw a dynamic dimension, we need to fix-up all the strides up to
-  // that dimension with the dynamic size.
-  if (dynSizeIdx) {
-    int64_t productOfAllStaticSizes = currentStride;
-    assert(ShapedType::isDynamic(sourceType.getDimSize(groupId)) &&
-           "We shouldn't be able to change dynamicity");
-    OpFoldResult origSize = origSizes[groupId];
-
-    AffineExpr s0 = builder.getAffineSymbolExpr(0);
-    AffineExpr s1 = builder.getAffineSymbolExpr(1);
-    for (; doneStrideIdx < *dynSizeIdx; ++doneStrideIdx) {
-      int64_t baseExpandedStride =
-          cast<IntegerAttr>(cast<Attribute>(expandedStrides[doneStrideIdx]))
-              .getInt();
-      expandedStrides[doneStrideIdx] = makeComposedFoldedAffineApply(
-          builder, expandShape.getLoc(),
-          (s0 * baseExpandedStride).floorDiv(productOfAllStaticSizes) * s1,
-          {origSize, origStride});
-    }
-  }
-
-  // Now apply the origStride to the remaining dimensions.
+  OpFoldResult origStride = origStrides[groupId];
   AffineExpr s0 = builder.getAffineSymbolExpr(0);
-  for (; doneStrideIdx < groupSize; ++doneStrideIdx) {
-    int64_t baseExpandedStride =
-        cast<IntegerAttr>(cast<Attribute>(expandedStrides[doneStrideIdx]))
-            .getInt();
-    expandedStrides[doneStrideIdx] = makeComposedFoldedAffineApply(
-        builder, expandShape.getLoc(), s0 * baseExpandedStride, {origStride});
+  AffineExpr s1 = builder.getAffineSymbolExpr(1);
+  int64_t resultOffset;
+  SmallVector<int64_t, 4> resultStrides;
+  (void)expandShapeType.getStridesAndOffset(resultStrides, resultOffset);
+  llvm::errs() << "currIdx = " << currIdx << "\n";
+  llvm::errs() << "groupId = " << groupId << "\n";
+  llvm::errs() << "resultStrides\n";
+  for (auto i: resultStrides)
+    llvm::errs() << i << " ";
+  llvm::errs() << "\n";
+  expandedStrides[groupSize - 1] =
+      !ShapedType::isDynamic(resultStrides[currIdx + groupSize - 1])
+          ? builder.getIndexAttr(resultStrides[currIdx + groupSize - 1])
+          : origStride;
+  OpFoldResult currentStride = one;
+  for (int i = groupSize - 2; i >= 0; i--) {
+    unsigned index = reassocGroup[i + 1];
+    // Multiply `currentStride` with `dimSize`.
+    currentStride =
+        expandShapeType.isDynamicDim(index)
+            ? makeComposedFoldedAffineApply(builder, loc, s0 * s1,
+                                            {currentStride, dynSizes[index]})
+            : makeComposedFoldedAffineApply(
+                  builder, loc, s0 * expandShapeType.getDimSize(index),
+                  {currentStride});
+    // Multiply `origStride` to all the strides in reassociation current group.
+    expandedStrides[i] = makeComposedFoldedAffineApply(
+        builder, loc, s0 * s1, {currentStride, origStride});
   }
-
+  llvm::errs() << "expandedStrides\n";
+  for (auto i: expandedStrides)
+    i.dump();
   return expandedStrides;
 }
 
diff --git a/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir b/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
index 647731db439c0..946fbac457a74 100644
--- a/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
+++ b/mlir/test/Dialect/MemRef/expand-strided-metadata.mlir
@@ -354,67 +354,22 @@ func.func @extract_strided_metadata_of_subview_all_dynamic(
 
 // Check that we properly simplify expand_shape into:
 // reinterpret_cast(extract_strided_metadata) + <some math>
-//
-// Here we have:
-// For the group applying to dim0:
-// size 0 = baseSizes#0 / (all static sizes in that group)
-//        = baseSizes#0 / (7 * 8 * 9)
-//        = baseSizes#0 / 504
-// size 1 = 7
-// size 2 = 8
-// size 3 = 9
-// stride 0 = baseStrides#0 * 7 * 8 * 9
-//          = baseStrides#0 * 504
-// stride 1 = baseStrides#0 * 8 * 9
-//          = baseStrides#0 * 72
-// stride 2 = baseStrides#0 * 9
-// stride 3 = baseStrides#0
-//
-// For the group applying to dim1:
-// size 4 = 10
-// size 5 = 2
-// size 6 = baseSizes#1 / (all static sizes in that group)
-//        = baseSizes#1 / (10 * 2 * 3)
-//        = baseSizes#1 / 60
-// size 7 = 3
-// stride 4 = baseStrides#1 * size 5 * size 6 * size 7
-//          = baseStrides#1 * 2 * (baseSizes#1 / 60) * 3
-//          = baseStrides#1 * (baseSizes#1 / 60) * 6
-//          and since we know that baseSizes#1 is a multiple of 60:
-//          = baseStrides#1 * (baseSizes#1 / 10)
-// stride 5 = baseStrides#1 * size 6 * size 7
-//          = baseStrides#1 * (baseSizes#1 / 60) * 3
-//          = baseStrides#1 * (baseSizes#1 / 20)
-// stride 6 = baseStrides#1 * size 7
-//          = baseStrides#1 * 3
-// stride 7 = baseStrides#1
-//
-// Base and offset are unchanged.
-//
-//   CHECK-DAG: #[[$DIM0_SIZE_MAP:.*]] = affine_map<()[s0] -> (s0 floordiv 504)>
-//   CHECK-DAG: #[[$DIM6_SIZE_MAP:.*]] = affine_map<()[s0] -> (s0 floordiv 60)>
-//
-//   CHECK-DAG: #[[$DIM0_STRIDE_MAP:.*]] = affine_map<()[s0] -> (s0 * 504)>
-//   CHECK-DAG: #[[$DIM1_STRIDE_MAP:.*]] = affine_map<()[s0] -> (s0 * 72)>
-//   CHECK-DAG: #[[$DIM2_STRIDE_MAP:.*]] = affine_map<()[s0] -> (s0 * 9)>
-//   CHECK-DAG: #[[$DIM4_STRIDE_MAP:.*]] = affine_map<()[s0, s1] -> ((s0 floordiv 10) * s1)>
-//   CHECK-DAG: #[[$DIM5_STRIDE_MAP:.*]] = affine_map<()[s0, s1] -> ((s0 floordiv 20) * s1)>
-//   CHECK-DAG: #[[$DIM6_STRIDE_MAP:.*]] = affine_map<()[s0] -> (s0 * 3)>
+// CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0] -> (s0 * 9)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 * 72)>
+// CHECK-DAG: #[[$MAP2:.*]] = affine_map<()[s0] -> (s0 * 504)>
+// CHECK-DAG: #[[$MAP3:.*]] = affine_map<()[s0] -> (s0 * 3)>
+// CHECK-DAG: #[[$MAP4:.*]] = affine_map<()[s0, s1] -> ((s1 * s0) * 3)>
+// CHECK-DAG: #[[$MAP5:.*]] = affine_map<()[s0, s1] -> ((s1 * s0) * 6)>
 // CHECK-LABEL: func @simplify_expand_shape
-//  CHECK-SAME: (%[[ARG:.*]]: memref<?x?xf32,
-//
+//  CHECK-SAME: (%[[ARG:.*]]: memref<?x?xf32, strided<[?, ?], offset: ?>>, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %[[sz0:.*]]: index, %[[sz1:.*]]: index
 //   CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]] : memref<?x?xf32, strided<[?, ?], offset: ?>> -> memref<f32>, index, index, index, index, index
-//
-//   CHECK-DAG: %[[DYN_SIZE0:.*]] = affine.apply #[[$DIM0_SIZE_MAP]]()[%[[SIZES]]#0]
-//   CHECK-DAG: %[[DYN_SIZE6:.*]] = affine.apply #[[$DIM6_SIZE_MAP]]()[%[[SIZES]]#1]
-//   CHECK-DAG: %[[DYN_STRIDE0:.*]] = affine.apply #[[$DIM0_STRIDE_MAP]]()[%[[STRIDES]]#0]
-//   CHECK-DAG: %[[DYN_STRIDE1:.*]] = affine.apply #[[$DIM1_STRIDE_MAP]]()[%[[STRIDES]]#0]
-//   CHECK-DAG: %[[DYN_STRIDE2:.*]] = affine.apply #[[$DIM2_STRIDE_MAP]]()[%[[STRIDES]]#0]
-//   CHECK-DAG: %[[DYN_STRIDE4:.*]] = affine.apply #[[$DIM4_STRIDE_MAP]]()[%[[SIZES]]#1, %[[STRIDES]]#1]
-//   CHECK-DAG: %[[DYN_STRIDE5:.*]] = affine.apply #[[$DIM5_STRIDE_MAP]]()[%[[SIZES]]#1, %[[STRIDES]]#1]
-//   CHECK-DAG: %[[DYN_STRIDE6:.*]] = affine.apply #[[$DIM6_STRIDE_MAP]]()[%[[STRIDES]]#1]
-//
-//   CHECK-DAG: %[[REINTERPRET_CAST:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[OFFSET]]], sizes: [%[[DYN_SIZE0]], 7, 8, 9, 10, 2, %[[DYN_SIZE6]], 3], strides: [%[[DYN_STRIDE0]], %[[DYN_STRIDE1]], %[[DYN_STRIDE2]], %[[STRIDES]]#0, %[[DYN_STRIDE4]], %[[DYN_STRIDE5]], %[[DYN_STRIDE6]], %[[STRIDES]]#1]
+//   CHECK-DAG: %[[DYN_STRIDE0:.*]] = affine.apply #map()[%[[STRIDES]]#0]
+//   CHECK-DAG: %[[DYN_STRIDE1:.*]] = affine.apply #map1()[%[[STRIDES]]#0]
+//   CHECK-DAG: %[[DYN_STRIDE2:.*]] = affine.apply #map2()[%[[STRIDES]]#0]
+//   CHECK-DAG: %[[DYN_STRIDE3:.*]] = affine.apply #map3()[%[[STRIDES]]#1]
+//   CHECK-DAG: %[[DYN_STRIDE4:.*]] = affine.apply #map4()[%[[STRIDES]]#1, %[[sz1]]]
+//   CHECK-DAG: %[[DYN_STRIDE5:.*]] = affine.apply #map5()[%[[STRIDES]]#1, %[[sz1]]]
+//   CHECK-DAG: %[[REINTERPRET_CAST:.*]] = memref.reinterpret_cast %[[BASE]] to offset: [%[[OFFSET]]], sizes: [%[[sz0]], 7, 8, 9, 10, 2, %[[sz1]], 3], strides: [%[[DYN_STRIDE2]], %[[DYN_STRIDE1]], %[[DYN_STRIDE0]], %[[STRIDES]]#0, %[[DYN_STRIDE5]], %[[DYN_STRIDE4]], %[[DYN_STRIDE3]], %[[STRIDES]]#1]
 //
 //   CHECK: return %[[REINTERPRET_CAST]]
 func.func @simplify_expand_shape(
@@ -525,73 +480,34 @@ func.func @extract_strided_metadata_of_expand_shape_all_static(
 // 2. We properly compute the strides affected by dynamic shapes. (When the
 //    dynamic dimension is not the first one.)
 //
-// Here we have:
-// For the group applying to dim0:
-// size 0 = baseSizes#0 / (all static sizes in that group)
-//        = baseSizes#0 / (7 * 8 * 9)
-//        = baseSizes#0 / 504
-// size 1 = 7
-// size 2 = 8
-// size 3 = 9
-// stride 0 = baseStrides#0 * 7 * 8 * 9
-//          = baseStrides#0 * 504
-// stride 1 = baseStrides#0 * 8 * 9
-//          = baseStrides#0 * 72
-// stride 2 = baseStrides#0 * 9
-// stride 3 = baseStrides#0
-//
-// For the group applying to dim1:
-// size 4 = 10
-// size 5 = 2
-// size 6 = baseSizes#1 / (all static sizes in that group)
-//        = baseSizes#1 / (10 * 2 * 3)
-//        = baseSizes#1 / 60
-// size 7 = 3
-// stride 4 = baseStrides#1 * size 5 * size 6 * size 7
-//          = baseStrides#1 * 2 * (baseSizes#1 / 60) * 3
-//          = baseStrides#1 * (baseSizes#1 / 60) * 6
-//          and since we know that baseSizes#1 is a multiple of 60:
-//          = baseStrides#1 * (baseSizes#1 / 10)
-// stride 5 = baseStrides#1 * size 6 * size 7
-//          = baseStrides#1 * (baseSizes#1 / 60) * 3
-//          = baseStrides#1 * (baseSizes#1 / 20)
-// stride 6 = baseStrides#1 * size 7
-//          = baseStrides#1 * 3
-// stride 7 = baseStrides#1
-//
 // Base and offset are unchanged.
 //
-//   CHECK-DAG: #[[$DIM0_SIZE_MAP:.*]] = affine_map<()[s0] -> (s0 floordiv 504)>
-//   CHECK-DAG: #[[$DIM6_SIZE_MAP:.*]] = affine_map<()[s0] -> (s0 floordiv 60)>
-//
-//   CHECK-DAG: #[[$DIM0_STRIDE_MAP:.*]] = affine_map<()[s0] -> (s0 * 504)>
-//   CHECK-DAG: #[[$DIM1_STRIDE_MAP:.*]] = affine_map<()[s0] -> (s0 * 72)>
-//   CHECK-DAG: #[[$DIM2_STRIDE_MAP:.*]] = affine_map<()[s0] -> (s0 * 9)>
-//   CHECK-DAG: #[[$DIM4_STRIDE_MAP:.*]] = affine_map<()[s0, s1] -> ((s0 floordiv 10) * s1)>
-//   CHECK-DAG: #[[$DIM5_STRIDE_MAP:.*]] = affine_map<()[s0, s1] -> ((s0 floordiv 20) * s1)>
-//   CHECK-DAG: #[[$DIM6_STRIDE_MAP:.*]] = affine_map<()[s0] -> (s0 * 3)>
+// CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0] -> (s0 * 9)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 * 72)>
+// CHECK-DAG: #[[$MAP2:.*]] = affine_map<()[s0] -> (s0 * 504)>
+// CHECK-DAG: #[[$MAP3:.*]] = affine_map<()[s0] -> (s0 * 3)>
+// CHECK-DAG: #[[$MAP4:.*]] = affine_map<()[s0, s1] -> ((s1 * s0) * 3)>
+// CHECK-DAG: #[[$MAP5:.*]] = affine_map<()[s0, s1] -> ((s1 * s0) * 6)>
 // CHECK-LABEL: func @extract_strided_metadata_of_expand_shape_all_dynamic
-//  CHECK-SAME: (%[[ARG:.*]]: memref<?x?xf32,
+//  CHECK-SAME: (%[[ARG:.*]]: memref<?x?xf32, strided<[?, ?], offset: ?>>, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %[[sz0:.*]]: index, %[[sz1:.*]]: index
 //
+//   CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
+//   CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
 //   CHECK-DAG: %[[C10:.*]] = arith.constant 10 : index
 //   CHECK-DAG: %[[C9:.*]] = arith.constant 9 : index
 //   CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
 //   CHECK-DAG: %[[C7:.*]] = arith.constant 7 : index
-//   CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
-//   CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
 //
 //   CHECK-DAG: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]] : memref<?x?xf32, strided<[?, ?], offset: ?>> -> memref<f32>, index, index, index, index, index
 //
-//   CHECK-DAG: %[[DYN_SIZE0:.*]] = affine.apply #[[$DIM0_SIZE_MAP]]()[%[[SIZES]]#0]
-//   CHECK-DAG: %[[DYN_SIZE6:.*]] = affine.apply #[[$DIM6_SIZE_MAP]]()[%[[SIZES]]#1]
-//   CHECK-DAG: %[[DYN_STRIDE0:.*]] = affine.apply #[[$DIM0_STRIDE_MAP]]()[%[[STRIDES]]#0]
-//   CHECK-DAG: %[[DYN_STRIDE1:.*]] = affine.apply #[[$DIM1_STRIDE_MAP]]()[%[[STRIDES]]#0]
-//   CHECK-DAG: %[[DYN_STRIDE2:.*]] = affine.apply #[[$DIM2_STRIDE_MAP]]()[%[[STRIDES]]#0]
-//   CHECK-DAG: %[[DYN_STRIDE4:.*]] = affine.apply #[[$DIM4_STRIDE_MAP]]()[%[[SIZES]]#1, %[[STRIDES]]#1]
-//   CHECK-DAG: %[[DYN_STRIDE5:.*]] = affine.apply #[[$DIM5_STRIDE_MAP]]()[%[[SIZES]]#1, %[[STRIDES]]#1]
-//   CHECK-DAG: %[[DYN_STRIDE6:.*]] = affine.apply #[[$DIM6_STRIDE_MAP]]()[%[[STRIDES]]#1]
-
-//   CHECK: return %[[BASE]], %[[OFFSET]], %[[DYN_SIZE0]], %[[C7]], %[[C8]], %[[C9]], %[[C10]], %[[C2]], %[[DYN_SIZE6]], %[[C3]], %[[DYN_STRIDE0]], %[[DYN_STRIDE1]], %[[DYN_STRIDE2]], %[[STRIDES]]#0, %[[DYN_STRIDE4]], %[[DYN_STRIDE5]], %[[DYN_STRIDE6]], %[[STRIDES]]#1 : memref<f32>, index, index, index, index, index, index, index, index, index, index, index, index, index
+//   CHECK-DAG: %[[DYN_STRIDE0:.*]] = affine.apply #[[$MAP]]()[%[[STRIDES]]#0]
+//   CHECK-DAG: %[[DYN_STRIDE1:.*]] = affine.apply #[[$MAP1]]()[%[[STRIDES]]#0]
+//   CHECK-DAG: %[[DYN_STRIDE2:.*]] = affine.apply #[[$MAP2]]()[%[[STRIDES]]#0]
+//   CHECK-DAG: %[[DYN_STRIDE3:.*]] = affine.apply #[[$MAP3]]()[%[[STRIDES]]#1]
+//   CHECK-DAG: %[[DYN_STRIDE4:.*]] = affine.apply #[[$MAP4]]()[%[[STRIDES]]#1, %[[sz1]]]
+//   CHECK-DAG: %[[DYN_STRIDE5:.*]] = affine.apply #[[$MAP5]]()[%[[STRIDES]]#1, %[[sz1]]]
+//
+//   CHECK: return %[[BASE]], %[[OFFSET]], %[[sz0]], %[[C7]], %[[C8]], %[[C9]], %[[C10]], %[[C2]], %[[sz1]], %[[C3]], %[[DYN_STRIDE2]], %[[DYN_STRIDE1]], %[[DYN_STRIDE0]], %[[STRIDES]]#0, %[[DYN_STRIDE5]], %[[DYN_STRIDE4]], %[[DYN_STRIDE3]], %[[STRIDES]]#1 : memref<f32>, index, index, index, index, index, index, index, index, index, index, index, index, index
 func.func @extract_strided_metadata_of_expand_shape_all_dynamic(
     %base: memref<?x?xf32, strided<[?,?], offset:?>>,
   ...
[truncated]

Copy link

github-actions bot commented Mar 4, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@arnab-polymage arnab-polymage changed the title Improve expand-strided-metadata pass [MLIR][Memref] Improve expand-strided-metadata pass Mar 4, 2025
@arnab-polymage arnab-polymage force-pushed the ornib/improve_expand_strided_metadata branch 2 times, most recently from ddf4d0c to be2719c Compare March 4, 2025 08:23
@bondhugula bondhugula requested a review from qcolombet March 4, 2025 08:31
@arnab-polymage arnab-polymage force-pushed the ornib/improve_expand_strided_metadata branch from be2719c to 165dfaf Compare March 4, 2025 08:56
Add support to obtain expanded sizes and strides of
result of memref.expand_shape op having multiple dynamic
dims within a reassociation set.
@arnab-polymage arnab-polymage force-pushed the ornib/improve_expand_strided_metadata branch from 165dfaf to b755dc0 Compare March 4, 2025 08:57
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants